#include "pwm_lib.h"

/*
 * Binds a pwmDevice to a timer handle/channel and programs PSC, ARR and the
 * channel's compare register for the requested frequency and duty cycle.
 *
 * dev      : device descriptor to fill in
 * handle   : HAL timer handle whose Instance registers are written directly
 * ch       : TIM_CHANNEL_1..TIM_CHANNEL_4
 * init_frq : requested PWM frequency in Hz (values below 1 Hz, or NaN, become 1 Hz)
 * init_dc  : requested duty cycle in percent (clamped to 0..100 by PWM_Set_DC)
 *
 * This does not start the output; call PWM_Start afterwards.
 */
void PWM_Initialise (pwmDevice *dev, TIM_HandleTypeDef *handle, uint32_t ch, float init_frq, float init_dc) {

    // Clamp before storing so dev->frequency records the value that is actually
    // programmed (same 1 Hz floor as PWM_Set_FR); the negated test also catches NaN
    if (!(init_frq >= 1.0f)) init_frq = 1.0f;

    dev -> handle     = handle;
    dev -> channel    = ch;
    dev -> frequency  = init_frq;
    dev -> duty_cycle = init_dc;

    uint32_t psc, arr;

    PWM_ComputePscArr (handle -> Instance, init_frq, &psc, &arr);

    dev -> prescaler = psc;

    handle -> Instance -> PSC = psc;
    handle -> Instance -> ARR = arr;
    handle -> Instance -> EGR = TIM_EGR_UG;   // force an update event so PSC/ARR take effect now

    // Must run after ARR is written: the compare value is derived from ARR
    PWM_Set_DC (dev, init_dc, ch);

}

/*
 * Reprograms the timer of dev for a new PWM frequency (Hz) and re-applies the
 * stored duty cycle against the new ARR.
 *
 * Frequencies below 1 Hz, and NaN, are clamped to 1 Hz.
 *
 * The channel is stopped (PWM_Stop) before reconfiguration and is NOT
 * restarted here; the caller must call PWM_Start to resume output.
 */
void PWM_Set_FR (pwmDevice *dev, float frequency) {

    // The negated comparison also maps NaN to the 1 Hz floor instead of storing it
    if (!(frequency >= 1.0f)) frequency = 1.0f;

    uint32_t psc, arr;

    PWM_ComputePscArr (dev -> handle -> Instance, frequency, &psc, &arr);

    // Stop PWM while we reconfigure

    PWM_Stop (dev, dev -> channel);

    dev -> prescaler = psc;
    dev -> frequency = frequency;

    dev -> handle -> Instance -> PSC = psc;
    dev -> handle -> Instance -> ARR = arr;
    dev -> handle -> Instance -> EGR = TIM_EGR_UG;   // load the new PSC/ARR immediately

    // Re-apply duty with new ARR

    PWM_Set_DC (dev, dev -> duty_cycle, dev -> channel);

}

/*
 * Sets the duty cycle (percent) of channel ch on dev's timer by writing the
 * matching CCRx register relative to the current ARR.
 *
 * duty is clamped to 0..100 (NaN becomes 0). The compare value is
 * duty% of (ARR + 1), rounded to the nearest tick, so 100% yields CCR = ARR + 1.
 * On a 32-bit timer with ARR = 0xFFFFFFFF, CCR saturates at 0xFFFFFFFF.
 *
 * An unknown ch leaves the registers untouched, but the clamped duty is still
 * recorded in dev->duty_cycle.
 */
void PWM_Set_DC (pwmDevice *dev, float duty, uint32_t ch) {

    // Negated test so NaN is clamped too (NaN compares false against everything)
    if (!(duty >= 0.0f)) duty = 0.0f;
    if (duty > 100.0f)   duty = 100.0f;

    TIM_TypeDef *tim = dev -> handle -> Instance;

    // Widen before adding 1: a 32-bit ARR of 0xFFFFFFFF would wrap to 0 in uint32_t
    uint64_t period = (uint64_t) tim -> ARR + 1u;

    // Multiply before dividing and use double so 32-bit periods keep full precision;
    // round to nearest so values like 33.3 % of 1000 give 333 rather than 332
    double   ccr_f = (double) duty * (double) period / 100.0;
    uint32_t ccr   = (ccr_f >= 4294967295.0) ? 0xFFFFFFFFu               // avoid out-of-range conversion (UB)
                                             : (uint32_t) (ccr_f + 0.5);

    switch (ch) {

        case TIM_CHANNEL_1: tim -> CCR1 = ccr; break;
        case TIM_CHANNEL_2: tim -> CCR2 = ccr; break;
        case TIM_CHANNEL_3: tim -> CCR3 = ccr; break;
        case TIM_CHANNEL_4: tim -> CCR4 = ccr; break;

        default: break;

    }

    dev -> duty_cycle = duty;

}

/* Starts PWM generation on channel ch of dev's timer via HAL_TIM_PWM_Start; the HAL status is discarded. */
void PWM_Start (pwmDevice *dev, uint32_t ch) {

    (void) HAL_TIM_PWM_Start (dev -> handle, ch);

}
/* Stops PWM generation on channel ch of dev's timer via HAL_TIM_PWM_Stop; the HAL status is discarded. */
void PWM_Stop  (pwmDevice *dev, uint32_t ch) {

    (void) HAL_TIM_PWM_Stop (dev -> handle, ch);

}

/*
 * Reads back the duty cycle (percent) currently programmed for dev->channel,
 * computed as CCRx / (ARR + 1) * 100 from the timer registers.
 *
 * Returns 0.0f if dev->channel is not TIM_CHANNEL_1..TIM_CHANNEL_4.
 */
float PWM_GetTIM_DC (pwmDevice *dev) {

    TIM_TypeDef *tim = dev -> handle -> Instance;
    uint32_t     ccr;

    switch (dev -> channel) {

        case TIM_CHANNEL_1: ccr = tim -> CCR1; break;
        case TIM_CHANNEL_2: ccr = tim -> CCR2; break;
        case TIM_CHANNEL_3: ccr = tim -> CCR3; break;
        case TIM_CHANNEL_4: ccr = tim -> CCR4; break;

        default: return 0.0f;

    }

    // Widen before adding 1: a 32-bit ARR of 0xFFFFFFFF would wrap to 0 and
    // make the division below produce inf/NaN
    uint64_t period = (uint64_t) tim -> ARR + 1u;

    return (float) ((double) ccr * 100.0 / (double) period);

}

/*
 * Returns the PWM frequency (Hz) implied by the timer's current registers:
 * timer_clock / ((PSC + 1) * (ARR + 1)), where timer_clock comes from
 * PWM_GetTIM_BaseFreq.
 */
float PWM_GetTIM_UsrFreq (pwmDevice *dev) {

    TIM_TypeDef *tim       = dev -> handle -> Instance;
    uint32_t     timer_clk = PWM_GetTIM_BaseFreq (tim);

    // Widen each register before adding 1. The original added in 32-bit, so a
    // 32-bit ARR of 0xFFFFFFFF wrapped to 0 and the function returned 0 Hz.
    // Both factors are now >= 1, so ticks can never be 0.
    uint64_t ticks = ((uint64_t) tim -> PSC + 1u) * ((uint64_t) tim -> ARR + 1u);

    return (float) ((double) timer_clk / (double) ticks);

}

/*
 * Returns 1 if TIMx is clocked from APB2, 0 if it is clocked from APB1.
 * Each timer is guarded with #if defined so the list compiles on parts that
 * lack some of them.
 * NOTE(review): APB2 membership (TIM1/8/9/10/11/15/16/17/20) follows the
 * STM32 reference manuals; confirm against the target's RM.
 */
static uint8_t PWM_IsApb2Timer (TIM_TypeDef *TIMx) {

    if (TIMx == TIM1) return 1u;
#if defined (TIM8)
    if (TIMx == TIM8) return 1u;
#endif
#if defined (TIM9)
    if (TIMx == TIM9) return 1u;
#endif
#if defined (TIM10)
    if (TIMx == TIM10) return 1u;
#endif
#if defined (TIM11)
    if (TIMx == TIM11) return 1u;
#endif
#if defined (TIM15)
    if (TIMx == TIM15) return 1u;
#endif
#if defined (TIM16)
    if (TIMx == TIM16) return 1u;
#endif
#if defined (TIM17)
    if (TIMx == TIM17) return 1u;
#endif
#if defined (TIM20)
    if (TIMx == TIM20) return 1u;
#endif

    return 0u;
}

/*
 * Returns the timer input clock in Hz (the clock divided by PSC) for TIMx.
 *
 * Rule applied: if the timer's APB bus prescaler is 1, the timer clock equals
 * PCLK; otherwise it is 2 * PCLK.
 *
 * NOTE(review): some parts (e.g. STM32F42x/F446 with RCC_DCKCFGR.TIMPRE set)
 * use a different multiplier; that case is not handled here. Confirm for the target.
 */
uint32_t PWM_GetTIM_BaseFreq (TIM_TypeDef *TIMx) {

    uint32_t pclk;
    uint8_t  apb_divided;

    if (PWM_IsApb2Timer (TIMx)) {

        pclk        = HAL_RCC_GetPCLK2Freq ();
        apb_divided = (RCC -> CFGR & RCC_CFGR_PPRE2) != RCC_CFGR_PPRE2_DIV1;

    } else {

        pclk        = HAL_RCC_GetPCLK1Freq ();
        apb_divided = (RCC -> CFGR & RCC_CFGR_PPRE1) != RCC_CFGR_PPRE1_DIV1;

    }

    return apb_divided ? pclk * 2u : pclk;

}

/*
 * Reports whether TIMx is treated as a 32-bit timer (TIM2 or TIM5).
 * Returns 1 for those two timers and 0 for every other timer; callers
 * then size ARR as 16-bit.
 */
uint8_t PWM_Is32bitTimer (TIM_TypeDef *TIMx) {

    if (TIMx == TIM2) return 1u;
    if (TIMx == TIM5) return 1u;

    return 0u;

}

/*
 * Computes PSC and ARR register values for a PWM period of 1/frequency on TIMx.
 *
 * Algorithm:
 *   - Pick the smallest prescaler that lets the period fit in ARR (16-bit, or
 *     32-bit per PWM_Is32bitTimer). This keeps the most ARR resolution for
 *     duty-cycle control.
 *   - Round ARR + 1 to the nearest whole tick to minimise frequency error.
 *
 * Limits:
 *   - frequency below 1 Hz, or NaN, is treated as 1 Hz.
 *   - Requests faster than timer_clock / 2 are clamped to ARR = 1. With ARR = 0
 *     the counter does not run, so a 2-tick period is the shortest usable one.
 *   - If the period does not fit even with PSC = 0xFFFF, the slowest possible
 *     setting (PSC = 0xFFFF, ARR = max) is returned.
 */
void PWM_ComputePscArr (TIM_TypeDef *TIMx, float frequency, uint32_t *psc_out, uint32_t *arr_out) {

    // Negated comparison so NaN is also mapped to the 1 Hz floor
    if (!(frequency >= 1.0f)) frequency = 1.0f;

    uint32_t timer_clk  = PWM_GetTIM_BaseFreq (TIMx);
    uint32_t max_arr    = PWM_Is32bitTimer (TIMx) ? 0xFFFFFFFFu : 0xFFFFu;
    double   max_period = (double) max_arr + 1.0;   // longest period (ticks) per prescaler step

    // Total timer ticks per PWM period at PSC = 0
    double total_ticks = (double) timer_clk / (double) frequency;

    // ARR = 0 stalls the counter, so never go below a 2-tick period
    if (total_ticks < 2.0) total_ticks = 2.0;

    // Smallest divider (PSC + 1) with total_ticks / divider <= max_period
    double divider_f = total_ticks / max_period;

    if (divider_f > 65536.0) {

        // Too slow even at the maximum prescaler: use the slowest configuration
        *psc_out = 0xFFFFu;
        *arr_out = max_arr;
        return;

    }

    uint32_t divider = (uint32_t) divider_f;          // divider_f > 0 here (total_ticks >= 2)
    if ((double) divider < divider_f) divider ++;     // ceil without <math.h>

    // Nearest whole tick count for the period, kept within [2, max_arr + 1]
    uint64_t period = (uint64_t) (total_ticks / (double) divider + 0.5);
    if (period > (uint64_t) max_arr + 1u) period = (uint64_t) max_arr + 1u;
    if (period < 2u)                      period = 2u;

    *psc_out = divider - 1u;
    *arr_out = (uint32_t) (period - 1u);

}
